%% Port-selection capacity versus channel-estimation SNR
% Builds an N_ports-port selection precoder from the true channel and from
% several channel estimates. It then evaluates each precoder's rate on the
% TRUE channel, averages over realizations, saves the curves and plots them.
clc; clear;   % plain 'clear' suffices; 'clear all' also flushes compiled code/breakpoints

% h_list.mat must provide h_list (true channels), the estimates
% h_hat_{GPR,exp,LMMSE,ML,OMP}_list and the SNR vector; none of them are
% defined in this script. The estimates are presumably the same size as
% h_list -- TODO confirm.
load('h_list.mat');
if ~exist('SNR','var')
    % Fail before the Monte-Carlo loop rather than at save/plot time.
    error('h_list.mat does not contain the SNR vector needed for saving and plotting.');
end

[N, Length, Rep] = size(h_list);   % dim 1: channel entries, dim 2: SNR points, dim 3: realizations

K = 1;          % realizations grouped into one N-by-K channel matrix per evaluation

N_ports = 4;    % number of entries (ports) kept by the selection precoder

% Noise power of the rate evaluation. It is fixed at 0 dB (sigma2 = 1) and
% does not vary with pp; the swept SNR is the channel-estimation SNR (see xlabel).
SNR_data_dB = 0;
sigma2 = 10.^(-SNR_data_dB/10);

nGroups = floor(Rep/K);   % integer number of K-sized groups even if K does not divide Rep

SumRate = zeros(Length,nGroups);
SumRate_GPR = zeros(Length,nGroups);
SumRate_exp = zeros(Length,nGroups);
SumRate_LMMSE = zeros(Length,nGroups);
SumRate_ML = zeros(Length,nGroups);
SumRate_OMP = zeros(Length,nGroups);

for pp = 1:Length
    for rp = 1:nGroups
        cols = (rp-1)*K+1 : rp*K;
        % reshape (instead of squeeze) keeps every matrix N-by-K even when N or K is 1
        H       = reshape(h_list(:,pp,cols), [], K);            % Real channel
        H_GPR   = reshape(h_hat_GPR_list(:,pp,cols), [], K);
        H_exp   = reshape(h_hat_exp_list(:,pp,cols), [], K);
        H_LMMSE = reshape(h_hat_LMMSE_list(:,pp,cols), [], K);
        H_ML    = reshape(h_hat_ML_list(:,pp,cols), [], K);
        H_OMP   = reshape(h_hat_OMP_list(:,pp,cols), [], K);    % OMP estimated channel

        % Precoders are derived from each (possibly estimated) channel, but
        % every rate is evaluated on the true channel H.
        SumRate(pp,rp)       = SR_calculate(select_ports(H,N_ports),       H, sigma2);  % Ideal
        SumRate_GPR(pp,rp)   = SR_calculate(select_ports(H_GPR,N_ports),   H, sigma2);
        SumRate_exp(pp,rp)   = SR_calculate(select_ports(H_exp,N_ports),   H, sigma2);
        SumRate_LMMSE(pp,rp) = SR_calculate(select_ports(H_LMMSE,N_ports), H, sigma2);
        SumRate_ML(pp,rp)    = SR_calculate(select_ports(H_ML,N_ports),    H, sigma2);
        SumRate_OMP(pp,rp)   = SR_calculate(select_ports(H_OMP,N_ports),   H, sigma2);
    end
end

% Average over realizations -> one value per SNR point (Length-by-1).
SumRate = mean(SumRate,2);
SumRate_GPR = mean(SumRate_GPR,2);
SumRate_exp = mean(SumRate_exp,2);
SumRate_LMMSE = mean(SumRate_LMMSE,2);
SumRate_ML = mean(SumRate_ML,2);
SumRate_OMP = mean(SumRate_OMP,2);

save('Capacity_SNR.mat','SNR','SumRate','SumRate_GPR','SumRate_exp','SumRate_LMMSE','SumRate_ML','SumRate_OMP');

C = linspecer(5);   % external colormap helper (not part of base MATLAB)

figure;
box on; grid on; hold on;
plot(SNR,SumRate,'--k','LineWidth',1.5);
plot(SNR,SumRate_GPR,'-d','LineWidth',1.5,'Color',C(1,:));
plot(SNR,SumRate_exp,'-r>','LineWidth',1.5,'Color',C(2,:));
plot(SNR,SumRate_LMMSE,'-p','LineWidth',1.5,'Color',C(5,:));
plot(SNR,SumRate_ML,'-o','LineWidth',1.5,'Color',C(3,:));
plot(SNR,SumRate_OMP,'-s','LineWidth',1.5,'Color',C(4,:));

legend('Oracle MMSE','Proposed S-BAR (${\bf \Sigma}_{\rm cov}$)',...
    'Proposed S-BAR (${\bf \Sigma}_{\rm exp}$)','SeLMMSE', ...
    'FAS-ML','FAS-OMP','FontSize',11,'Interpreter','latex');

xlabel('SNR for channel estimation (dB)','Interpreter','latex','FontSize',14);
ylabel('Capacity (bps/Hz)','Interpreter','latex','FontSize',14);

function W = select_ports(H, N_ports)
% SELECT_PORTS  Port-selection precoder.
%   Keeps the N_ports largest-magnitude entries of each column of H, zeroes
%   all other entries and scales the result to unit Frobenius norm.
%   If every kept entry is zero, the all-zero matrix is returned instead of NaN.
W = zeros(size(H), 'like', H);
[~, rows] = maxk(abs(H), N_ports, 1);        % per-column row indices of the strongest entries
lin = rows + size(H,1) * (0:size(H,2)-1);     % per-column row index -> linear index
W(lin) = H(lin);
nrm = norm(W, 'fro');
if nrm > 0
    W = W / nrm;
end
end

%% Functions
function SR = SR_calculate(W,H,sigma2)
% SR_CALCULATE  Sum over users of log2(1 + SINR_k) for linear precoding.
%   W      - N-by-K precoder; column k is the beam for user k (K = size(W,2))
%   H      - channel matrix with at least K columns; column k is user k's channel
%   sigma2 - noise power
%   SINR_k = |h_k'*w_k|^2 / ( sum_{j~=k} |h_k'*w_j|^2 + sigma2 )
% The interference at user k is the power of the OTHER users' beams received
% along h_k (row k of G). It is not user k's leakage |h_j'*w_k|^2 into others;
% the two differ whenever K > 1.
K = size(W,2);
G = abs(H(:,1:K)' * W).^2;      % G(k,j) = |h_k' * w_j|^2 (' is the conjugate transpose)
signal = diag(G);
G(1:K+1:end) = 0;                % remove desired-signal terms; avoids add-then-subtract
interference = sum(G, 2);
SR = sum(log2(1 + signal ./ (interference + sigma2)));
end



